#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File      :   ScConv.py
@Time      :   2024/02/26 20:17:17
@Author    :   CSDN迪菲赫尔曼 
@Version   :   1.0
@Reference :   https://blog.csdn.net/weixin_43694096
@Desc      :   ScConv module: an SRU stage followed by a CRU stage.
"""


import torch
import torch.nn.functional as F
import torch.nn as nn

# ``__all__`` must be a sequence of names; a bare string would be iterated
# character by character by ``from <module> import *`` and fail.
__all__ = ["ScConv"]


class GroupBatchnorm2d(nn.Module):
    """Group-wise normalization followed by a learnable per-channel affine map.

    Every sample is flattened into ``group_num`` groups, each group is
    standardized with its own mean and standard deviation, and the result is
    scaled by ``gamma`` and shifted by ``beta`` (both shaped ``(c_num, 1, 1)``).

    Args:
        c_num: Number of input channels; sizes ``gamma`` and ``beta``.
        group_num: Number of groups each sample is divided into. The number of
            elements per sample (``C * H * W``) must be divisible by it.
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self, c_num: int, group_num: int = 16, eps: float = 1e-10):
        super().__init__()
        assert c_num >= group_num, (
            f"c_num ({c_num}) must be >= group_num ({group_num})"
        )
        self.group_num = group_num
        # gamma starts from a standard normal draw, beta from zeros.
        self.gamma = nn.Parameter(torch.randn(c_num, 1, 1))
        self.beta = nn.Parameter(torch.zeros(c_num, 1, 1))
        self.eps = eps

    def forward(self, x):
        """Normalize a ``(N, C, H, W)`` tensor and return one of the same shape."""
        N, C, H, W = x.size()
        # ``reshape`` instead of ``view``: ``view`` raises on inputs whose
        # strides cannot express the grouping (e.g. channels_last tensors),
        # while ``reshape`` copies only when it has to.
        grouped = x.reshape(N, self.group_num, -1)
        mean = grouped.mean(dim=2, keepdim=True)
        std = grouped.std(dim=2, keepdim=True)
        normalized = (grouped - mean) / (std + self.eps)
        normalized = normalized.reshape(N, C, H, W)
        return normalized * self.gamma + self.beta


class SRU(nn.Module):
    """Gate features channel-wise and cross-add the halves of the two branches.

    Args:
        oup_channels: Channel count passed to the internal ``GroupBatchnorm2d``.
        group_num: Group count passed to the internal ``GroupBatchnorm2d``.
        gate_treshold: Cut-off compared against the softmax-normalized ``gamma``
            of the normalization layer to assign each channel to a branch.
    """

    def __init__(
        self, oup_channels: int, group_num: int = 16, gate_treshold: float = 0.5
    ):
        super().__init__()

        self.gate_treshold = gate_treshold
        self.gn = GroupBatchnorm2d(oup_channels, group_num=group_num)
        self.sigomid = nn.Sigmoid()

    def forward(self, x):
        """Re-weight ``x``, route channels into two branches and recombine them."""
        normalized = self.gn(x)
        # Per-channel importance: softmax of gamma along the channel axis.
        channel_weights = F.softmax(self.gn.gamma, dim=0)
        scores = self.sigomid(normalized * channel_weights)

        # NOTE(review): the softmax sums to 1 over the channels, so at most one
        # channel can exceed a threshold of 0.5 -- confirm this gating rule is
        # the intended one.
        above = torch.gt(channel_weights, self.gate_treshold)
        not_above = torch.le(channel_weights, self.gate_treshold)

        informative = above * scores * x
        remainder = not_above * scores * x
        return self.reconstruct(informative, remainder)

    def reconstruct(self, x_1, x_2):
        """Add each half of ``x_1`` to the opposite half of ``x_2`` along dim 1."""
        first_a, second_a = x_1.split(x_1.size(1) // 2, dim=1)
        first_b, second_b = x_2.split(x_2.size(1) // 2, dim=1)
        crossed = (first_a + second_b, second_a + first_b)
        return torch.cat(crossed, dim=1)


class CRU(nn.Module):
    """Split channels, transform both parts, then fuse them with channel weights.

    Args:
        op_channel: Number of input channels; the output has the same count.
        alpha: Fraction of channels routed to the upper branch (0 < alpha < 1).
        squeeze_radio: Channel reduction factor of the two 1x1 squeeze convs.
        group_size: ``groups`` argument of the group-wise convolution.
        group_kernel_size: Kernel size of the group-wise convolution.
    """

    def __init__(
        self,
        op_channel: int,
        alpha: float = 1 / 2,
        squeeze_radio: int = 2,
        group_size: int = 2,
        group_kernel_size: int = 3,
    ):
        super().__init__()
        upper_in = int(alpha * op_channel)
        lower_in = op_channel - upper_in
        upper_mid = upper_in // squeeze_radio
        lower_mid = lower_in // squeeze_radio

        self.up_channel = upper_in
        self.low_channel = lower_in

        # 1x1 convolutions that shrink each part before it is transformed.
        self.squeeze1 = nn.Conv2d(upper_in, upper_mid, kernel_size=1, bias=False)
        self.squeeze2 = nn.Conv2d(lower_in, lower_mid, kernel_size=1, bias=False)

        # Upper branch: group-wise conv plus point-wise conv, both to op_channel.
        self.GWC = nn.Conv2d(
            upper_mid,
            op_channel,
            kernel_size=group_kernel_size,
            stride=1,
            padding=group_kernel_size // 2,
            groups=group_size,
        )
        self.PWC1 = nn.Conv2d(upper_mid, op_channel, kernel_size=1, bias=False)

        # Lower branch: the point-wise conv supplies the channels that, together
        # with the squeezed input itself, add up to op_channel.
        self.PWC2 = nn.Conv2d(
            lower_mid, op_channel - lower_mid, kernel_size=1, bias=False
        )
        self.advavg = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return a tensor with ``op_channel`` channels computed from ``x``."""
        # Split along the channel axis and squeeze each part.
        upper, lower = x.split([self.up_channel, self.low_channel], dim=1)
        upper = self.squeeze1(upper)
        lower = self.squeeze2(lower)

        # Transform both parts.
        upper_out = self.GWC(upper) + self.PWC1(upper)
        lower_out = torch.cat((self.PWC2(lower), lower), dim=1)

        # Fuse: weight the stacked channels by a softmax over their pooled
        # values, then add the two halves together.
        stacked = torch.cat((upper_out, lower_out), dim=1)
        channel_weights = F.softmax(self.advavg(stacked), dim=1)
        stacked = channel_weights * stacked
        first_half, second_half = stacked.split(stacked.size(1) // 2, dim=1)
        return first_half + second_half


class ScConv(nn.Module):
    """Run an ``SRU`` stage and then a ``CRU`` stage on the input.

    Args:
        op_channel: Channel count given to both stages.
        group_num: Forwarded to ``SRU``.
        gate_treshold: Forwarded to ``SRU``.
        alpha: Forwarded to ``CRU``.
        squeeze_radio: Forwarded to ``CRU``.
        group_size: Forwarded to ``CRU``.
        group_kernel_size: Forwarded to ``CRU``.
    """

    def __init__(
        self,
        op_channel: int,
        group_num: int = 16,
        gate_treshold: float = 0.5,
        alpha: float = 1 / 2,
        squeeze_radio: int = 2,
        group_size: int = 2,
        group_kernel_size: int = 3,
    ):
        super().__init__()
        cru_options = {
            "alpha": alpha,
            "squeeze_radio": squeeze_radio,
            "group_size": group_size,
            "group_kernel_size": group_kernel_size,
        }
        self.SRU = SRU(op_channel, group_num=group_num, gate_treshold=gate_treshold)
        self.CRU = CRU(op_channel, **cru_options)

    def forward(self, x):
        """Return ``CRU(SRU(x))``."""
        return self.CRU(self.SRU(x))
